#%% Importação dos pacotes
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score
from matplotlib import cm

#%% Carregamento dos dados
matlab_data = scipy.io.loadmat('MACHINE_Data.mat', struct_as_record=False)
Etch_data = matlab_data['LAMDATA']
calibration_dataAll = Etch_data[0,0].calibration
variable_names = Etch_data[0,0].variables

#%% Pré-processamento (Unfolding)
n_vars = variable_names.size - 2
n_samples = 85

unfolded_dataMatrix = np.empty((1, n_vars * n_samples))
for expt in range(calibration_dataAll.size):
    calibration_expt = calibration_dataAll[expt,0][5:90, 2:]
    if calibration_expt.shape[0] < 85:
        continue
    unfolded_row = np.ravel(calibration_expt, order='F')[np.newaxis, :]
    unfolded_dataMatrix = np.vstack((unfolded_dataMatrix, unfolded_row))
unfolded_dataMatrix = unfolded_dataMatrix[1:, :]

#%% Normalização + PCA
scaler = StandardScaler()
data_train_normal = scaler.fit_transform(unfolded_dataMatrix)
pca = PCA(n_components=3)
score_train = pca.fit_transform(data_train_normal)

#%% Plot inicial: sem separação por cluster
plt.figure()
plt.scatter(score_train[:, 0], score_train[:, 1], color='blue')
plt.xlabel('PC1 scores')
plt.ylabel('PC2 scores')
plt.title('Amostras sem separação por cluster')
plt.grid(False)
plt.show()

#%% Loop interativo
while True:
    try:
        n_cluster = int(input("Digite o número de clusters desejado (2-10) ou 0 para ver os gráficos de análise: "))
    except ValueError:
        print("Por favor, insira um número inteiro válido.")
        continue

    if n_cluster == 0:
        # Método do cotovelo e silhueta
        k_values = range(2, 11)
        sse = []
        silhouette_avgs = []

        for k in k_values:
            kmeans = KMeans(n_clusters=k, random_state=100, n_init=10).fit(score_train)
            labels = kmeans.labels_
            sse.append(kmeans.inertia_)
            silhouette_avgs.append(silhouette_score(score_train, labels))

        # Gráfico do cotovelo (SSE)
        plt.figure()
        plt.plot(k_values, sse, marker='o', color='red')
        plt.xlabel('Número de Clusters (k)')
        plt.ylabel('Soma dos Erros Quadráticos (SSE)')
        plt.title('Método do Cotovelo')
        plt.grid(False)
        plt.show()

        # Gráfico de silhueta média
        plt.figure()
        plt.plot(k_values, silhouette_avgs, marker='o', color='blue')
        plt.xlabel('Número de Clusters (k)')
        plt.ylabel('Coeficiente de Silhueta Médio')
        plt.title('Silhueta Média por Número de Clusters')
        plt.grid(False)
        plt.show()

    elif 2 <= n_cluster <= 10:
        # K-Means com número de clusters especificado
        kmeans = KMeans(n_clusters=n_cluster, random_state=100, n_init=10).fit(score_train)
        cluster_labels = kmeans.labels_

        # Visualização com clusters
        plt.figure()
        plt.scatter(score_train[:, 0], score_train[:, 1], c=cluster_labels, cmap='viridis', s=25)
        for i, center in enumerate(kmeans.cluster_centers_):
            plt.scatter(center[0], center[1], c='red', marker='*', s=80)
            plt.annotate(f'Cluster {i+1}', (center[0], center[1]))
        plt.xlabel('PC1 scores')
        plt.ylabel('PC2 scores')
        plt.title(f'K-Means com {n_cluster} clusters')
        plt.grid(False)
        plt.show()

        # Análise de silhueta
        silhouette_vals = silhouette_samples(score_train, cluster_labels)
        silhouette_avg = silhouette_score(score_train, cluster_labels)
        print(f'Coeficiente de silhueta médio: {silhouette_avg:.3f}')

        plt.figure()
        y_lower = 0
        yticks = []
        for i in range(n_cluster):
            ith_cluster_vals = silhouette_vals[cluster_labels == i]
            ith_cluster_vals.sort()
            y_upper = y_lower + len(ith_cluster_vals)
            color = cm.nipy_spectral(float(i) / n_cluster)
            plt.barh(range(y_lower, y_upper), ith_cluster_vals, color=color, edgecolor='none')
            yticks.append((y_lower + y_upper) / 2)
            y_lower = y_upper

        plt.axvline(silhouette_avg, color="red", linestyle="--")
        plt.yticks(yticks, [f'Cluster {i+1}' for i in range(n_cluster)])
        plt.xlabel("Coeficiente de Silhueta")
        plt.ylabel("Cluster")
        plt.title("Análise de Silhueta")
        plt.grid(False)
        plt.show()

    else:
        print("Por favor, digite um número entre 2 e 10 ou 0 para ver os gráficos.")

    # Pergunta de encerramento
    continuar = input("Deseja testar outro número de clusters? (S para sim, N para não): ").strip().upper()
    if continuar == 'N':
        print("Programa encerrado.")
        break